{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates\n",
    "# All rights reserved.\n",
    "#\n",
    "# This source code is licensed under the license found in the\n",
    "# MIT_LICENSE file in the root directory of this source tree."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MUTOX toxicity classification\n",
    "\n",
    "Mutox lets you score the toxicity of speech and text with a classifier that operates on SONAR embeddings. This notebook shows how to encode speech and text into SONAR embeddings and classify them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from pathlib import Path\n",
    "\n",
    "# Use the first GPU in half precision when CUDA is available;\n",
    "# otherwise fall back to the CPU in single precision.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
    "dtype = torch.float16 if use_cuda else torch.float32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Speech Scoring"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. download some demo audio segments\n",
    "2. create a tsv file to feed to the speech scoring pipeline\n",
    "3. load the model and build the pipeline\n",
    "4. iterate over the batches produced by the pipeline and print the score of each file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get demo file\n",
    "import urllib.request\n",
    "import tempfile\n",
    "\n",
    "# (download URL, local file name) pairs for the demo audio segments\n",
    "files = [\n",
    "    (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n",
    "    (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n",
    "]\n",
    "\n",
    "# Scratch directory holding the downloaded audio and the tsv manifest.\n",
    "tmpdir = Path(tempfile.mkdtemp())\n",
    "tsv_file = tmpdir / \"data.tsv\"\n",
    "\n",
    "# The manifest is a 'path' header line followed by one local audio path per line.\n",
    "with tsv_file.open(\"w\") as manifest:\n",
    "    manifest.write(\"path\\n\")\n",
    "    for url, filename in files:\n",
    "        local_path = tmpdir / filename\n",
    "        urllib.request.urlretrieve(url, local_path)\n",
    "        manifest.write(f\"{local_path}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sonar.inference_pipelines.speech import SpeechInferenceParams\n",
    "from seamless_communication.toxicity.mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n",
    "\n",
    "# Load the \"mutox\" classifier together with the \"sonar_speech_encoder_eng\"\n",
    "# speech encoder onto the selected device.\n",
    "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n",
    "    mutox_classifier_name=\"mutox\",\n",
    "    encoder_name=\"sonar_speech_encoder_eng\",  # plain string: no placeholders, so no f-prefix\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Describe how the tsv manifest is read and batched, then build the scoring pipeline.\n",
    "speech_params = SpeechInferenceParams(\n",
    "    data_file=tsv_file,\n",
    "    audio_root_dir=None,\n",
    "    audio_path_index=0,  # the manifest has a single column: the audio path\n",
    "    target_lang=\"eng\",\n",
    "    batch_size=4,\n",
    "    pad_idx=0,\n",
    "    device=device,\n",
    "    fbank_dtype=torch.float32,\n",
    "    n_parallel=4,\n",
    ")\n",
    "pipeline = pipeline_builder.build_pipeline(speech_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n",
      "/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n"
     ]
    }
   ],
   "source": [
    "# Print one tab-separated line per audio file: its path, then its score.\n",
    "for batch in pipeline:\n",
    "    audio = batch[\"audio\"]\n",
    "    for i, audio_path in enumerate(audio[\"path\"]):\n",
    "        print(f\"{audio_path}\\t{audio['data'][i].item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cleanup tmp dir\n",
    "import shutil\n",
    "\n",
    "# ignore_errors=True keeps this cell idempotent: re-running it after the\n",
    "# directory has already been removed does not raise.\n",
    "shutil.rmtree(tmpdir, ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text Scoring\n",
    "\n",
    "1. load the sonar text encoder\n",
    "2. load the mutox classifier model\n",
    "3. compute embedding for a sentence\n",
    "4. score this embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n"
     ]
    }
   ],
   "source": [
    "from seamless_communication.toxicity.mutox.loader import load_mutox_model\n",
    "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n",
    "\n",
    "# Text encoder pipeline used to turn sentences into embeddings.\n",
    "t2vec_model = TextToEmbeddingModelPipeline(\n",
    "    encoder=\"text_sonar_basic_encoder\",\n",
    "    tokenizer=\"text_sonar_basic_encoder\",\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Mutox classifier, loaded with the notebook-wide device/dtype and switched to eval mode.\n",
    "classifier = load_mutox_model(\n",
    "    \"mutox\",\n",
    "    device=device,\n",
    "    dtype=dtype,\n",
    ").eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with torch.inference_mode():\n",
    "    emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n",
    "    # Cast the embedding to the classifier's device and dtype. A hard-coded .half()\n",
    "    # only matches on CUDA; on CPU the classifier is loaded as float32.\n",
    "    x = classifier(emb.to(device=device, dtype=dtype))\n",
    "\n",
    "# Raw classifier output for the sentence (displayed as the cell result).\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sc_fr2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
